import math
import torch
from torch.distributions.normal import Normal
import torch.nn.functional as F

def make_beta_schedule(schedule="linear", num_timesteps=1000, start=1e-5, end=1e-2):
    """Build a 1-D tensor of ``num_timesteps`` diffusion noise rates (betas).

    Supported ``schedule`` values:
        "linear"          evenly spaced from ``start`` to ``end``.
        "const"           every entry equal to ``end``.
        "quad"            squares of values evenly spaced from sqrt(start)
                          to sqrt(end).
        "jsd"             1/N, 1/(N-1), ..., 1 (``start``/``end`` ignored).
        "sigmoid"         ``start + (end - start) * sigmoid(x)`` for x evenly
                          spaced over [-6, 6].
        "cosine",
        "cosine_reverse"  beta_i = 1 - abar(i+1)/abar(i) with
                          abar(s) = cos(((s/N + 0.008) / 1.008) * pi/2) ** 2,
                          clipped at 0.999 (``start``/``end`` ignored). Both
                          names currently produce the same schedule.
        "cosine_anneal"   half-cosine ramp from ``start`` (first step) to
                          ``end`` (last step).

    Raises:
        ValueError: if ``schedule`` is not one of the names above.
    """
    if schedule == "linear":
        betas = torch.linspace(start, end, num_timesteps)
    elif schedule == "const":
        betas = end * torch.ones(num_timesteps)
    elif schedule == "quad":
        betas = torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2
    elif schedule == "jsd":
        betas = 1.0 / torch.linspace(num_timesteps, 1, num_timesteps)
    elif schedule == "sigmoid":
        betas = torch.sigmoid(torch.linspace(-6, 6, num_timesteps)) * (end - start) + start
    elif schedule in ("cosine", "cosine_reverse"):
        max_beta = 0.999
        cosine_s = 0.008

        def alpha_bar(step):
            # Squared-cosine cumulative signal level at (fractional) step.
            return math.cos((step / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2

        betas = torch.tensor(
            [min(1 - alpha_bar(i + 1) / alpha_bar(i), max_beta) for i in range(num_timesteps)])
    elif schedule == "cosine_anneal":
        # Guard the denominator so a single-step schedule yields [start]
        # instead of raising ZeroDivisionError.
        denom = max(num_timesteps - 1, 1)
        betas = torch.tensor(
            [start + 0.5 * (end - start) * (1 - math.cos(t / denom * math.pi))
             for t in range(num_timesteps)])
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(f"Unknown beta schedule: {schedule!r}")
    return betas

def extract(input, t, x):
    """Pick schedule entries at indices ``t`` and shape them to broadcast against ``x``.

    ``input`` is gathered along dim 0 with ``t`` (moved to ``input``'s device).
    The result is reshaped to ``(len(t), 1, ..., 1)`` with as many trailing
    singleton dims as ``x`` has beyond its first.
    """
    picked = input.gather(0, t.to(input.device))
    trailing_ones = (1,) * (len(x.shape) - 1)
    return picked.reshape((t.shape[0],) + trailing_ones)

# Forward functions
def q_sample(y, y_feature, variance, alphas_bar_sqrt, t, noise, prototype):
    """
    Forward-diffuse a label tensor and its feature tensor to timestep ``t``.

    Args:
        y: label tensor to diffuse; softmax over dim 1 is applied to the result.
        y_feature: feature tensor diffused alongside ``y``.
        variance: per-element std used to draw noise when ``noise`` is None.
            Its first half of columns (dim 1) drives the label noise and its
            second half the feature noise. Ignored when ``noise`` is given.
        alphas_bar_sqrt: schedule of sqrt(alpha_bar) values, gathered at ``t``
            (presumably 1-D; TODO confirm).
        t: index tensor of timesteps; its length must broadcast with the batch dim.
        noise: pre-drawn noise, split along dim 1 the same way as ``variance``,
            or None to sample it from ``variance``.
        prototype: matrix applied as ``noise_feature @ prototype`` to map the
            feature half of the noise into feature space.

    Returns:
        (y_t, y_feature_t): softmax-normalised noisy labels and noisy features.
    """
    # Slice only after the None check; slicing first made noise=None crash
    # with a TypeError and left the sampling branch unreachable.
    if noise is None:
        half = variance.shape[1] // 2
        noise_label = torch.normal(0., std=variance[:, :half]).to(y.device)
        noise_feature = torch.normal(0., std=variance[:, half:]).to(y.device)
    else:
        half = noise.shape[1] // 2
        noise_label = noise[:, :half]
        noise_feature = noise[:, half:]
    sqrt_alpha_bar_t = extract(alphas_bar_sqrt, t, y)
    sqrt_alpha_bar_t_feature = extract(alphas_bar_sqrt, t, y_feature)
    # q(y_t | y_0, x)
    # NOTE(review): the noise coefficient is (1 - sqrt(alpha_bar)), not the
    # sqrt(1 - alpha_bar) used by the reverse-process helpers in this file;
    # kept unchanged — confirm this is intended.
    y_t = sqrt_alpha_bar_t * y + (1 - sqrt_alpha_bar_t) * noise_label
    y_feature_t = (sqrt_alpha_bar_t_feature * y_feature
                   + (1 - sqrt_alpha_bar_t_feature) * torch.matmul(noise_feature, prototype))

    y_t = F.softmax(y_t, dim=1)
    return y_t, y_feature_t

def y_0_reparam(model, y_feature, y, prototype, t, one_minus_alphas_bar_sqrt):
    """
    Estimate clean labels and features at step 0 from their noisy versions.

    The model is called as ``model(y_feature, t)``. The ``'noise'`` entry of its
    output is split along dim 1: the first half is the label noise estimate,
    and the second half, multiplied by ``prototype``, is the feature noise
    estimate. Each estimate is removed using the sqrt(1 - alpha_bar) value at
    ``t`` and the result is rescaled by 1 / sqrt(alpha_bar).

    Returns:
        (label_0, feature_0): softmax-normalised label estimate (dim 1) and the
        feature estimate.
    """
    device = next(model.parameters()).device
    sigma_label = extract(one_minus_alphas_bar_sqrt, t, y)
    sigma_feat = extract(one_minus_alphas_bar_sqrt, t, y_feature)
    scale_label = (1 - sigma_label.square()).sqrt()
    scale_feat = (1 - sigma_feat.square()).sqrt()
    eps = model(y_feature, t)['noise'].to(device).detach()
    split = eps.shape[1] // 2
    eps_label = eps[:, :split]
    eps_feat = eps[:, split:]

    # Label branch: invert y = scale * y_0 + sigma * eps, then normalise.
    label_0 = 1 / scale_label * (y - eps_label * sigma_label).to(device)
    label_0 = F.softmax(label_0, dim=1)

    # Feature branch: the noise estimate is projected through the prototypes.
    feature_0 = 1 / scale_feat * (
        y_feature - torch.matmul(eps_feat, prototype) * sigma_feat).to(device)
    return label_0, feature_0


def _entropy_variance_hat(y_0_pred, relation_matrix, k):
    """Derive a variance scale from the entropy of a label prediction.

    For each sample, the argmax class selects a row of rank positions derived
    from ``relation_matrix`` (rank 0 = largest entry in that row). Those ranks
    are scored under a zero-mean Normal whose scale is the sample's prediction
    entropy in bits. The densities are multiplied by ``k``, squeezed, and
    clamped to [0.1, 10.0].

    Assumes ``y_0_pred`` is 2-D (batch, classes) and ``relation_matrix`` is
    (classes, classes) — TODO confirm against callers.

    May raise ValueError (torch.distributions argument validation, when
    enabled) if any entropy is not strictly positive.
    """
    # NOTE(review): p_sample passes an already-softmaxed tensor, so this is a
    # second softmax; kept to preserve existing behaviour.
    probs = torch.softmax(y_0_pred, dim=-1)
    label_entropy = -torch.sum(probs * torch.log2(probs + 1e-8), dim=1)
    _, max_idx = torch.max(probs, dim=-1, keepdim=True)

    # sorted_ranks[i, j] = position of column j when row i is sorted descending.
    sorted_indices = torch.argsort(relation_matrix, dim=1, descending=True)
    sorted_ranks = torch.argsort(sorted_indices, dim=1)
    dis = sorted_ranks[max_idx]

    # One batched Normal instead of a Python list of per-sample distributions:
    # each sample's entropy is reshaped so it broadcasts over that sample's ranks.
    scale = label_entropy.reshape(-1, *([1] * (dis.dim() - 1)))
    density = torch.exp(Normal(0, scale).log_prob(dis)) * k
    return density.squeeze().clamp(0.1, 10.0)


def p_sample(model, y, y_feature, t, alphas, one_minus_alphas_bar_sqrt, prototype, relation_matrix, k, variance_hat=None):
    """
    Reverse diffusion process sampling -- one time step (t -> t-1).

    Args:
        model: denoiser called as ``model(y_feature, t)``. Its output dict's
            ``'noise'`` entry holds [label eps | feature eps] along dim 1.
        y: current noisy labels.
        y_feature: current noisy features.
        t: integer timestep index; must be >= 1 because index t - 1 is also read.
        alphas: schedule of alpha values, gathered at ``t``.
        one_minus_alphas_bar_sqrt: schedule of sqrt(1 - alpha_bar) values.
        prototype: matrix mapping feature-half noise into feature space (matmul).
        relation_matrix, k: used only when ``variance_hat`` is None, to derive
            it from prediction entropy (see ``_entropy_variance_hat``).
        variance_hat: multiplier on the label posterior variance and on the
            noise used in the feature update; derived when None.

    Returns:
        (y_t_m_1, y_feature_t_m_1): softmax-normalised labels (dim 1) and
        features at step t-1.
    """
    device = next(model.parameters()).device
    z = torch.randn_like(y)
    t = torch.tensor([t]).to(device)
    alpha_t = extract(alphas, t, y)
    alpha_t_feature = extract(alphas, t, y_feature)
    sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y)
    sqrt_one_minus_alpha_bar_t_feature = extract(one_minus_alphas_bar_sqrt, t, y_feature)
    sqrt_one_minus_alpha_bar_t_m_1 = extract(one_minus_alphas_bar_sqrt, t - 1, y)
    sqrt_one_minus_alpha_bar_t_m_1_feature = extract(one_minus_alphas_bar_sqrt, t - 1, y_feature)
    sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
    sqrt_alpha_bar_t_feature = (1 - sqrt_one_minus_alpha_bar_t_feature.square()).sqrt()
    sqrt_alpha_bar_t_m_1 = (1 - sqrt_one_minus_alpha_bar_t_m_1.square()).sqrt()
    # y_t_m_1 posterior mean component coefficients (the label-branch gammas
    # are reused for the feature branch).
    gamma_0 = (1 - alpha_t) * sqrt_alpha_bar_t_m_1 / (sqrt_one_minus_alpha_bar_t.square())
    gamma_1 = (sqrt_one_minus_alpha_bar_t_m_1.square()) * (alpha_t.sqrt()) / (sqrt_one_minus_alpha_bar_t.square())

    eps_theta = model(y_feature, t)['noise'].to(device).detach()
    half = eps_theta.shape[1] // 2
    # y_0 reparameterization
    y_0_pred = 1 / sqrt_alpha_bar_t * (
            y - eps_theta[:, :half] * sqrt_one_minus_alpha_bar_t).to(device)
    y_0_pred = F.softmax(y_0_pred, dim=1)
    # y_feature_0 reparameterization
    y_feature_0_pred = 1 / sqrt_alpha_bar_t_feature * (
        y_feature - torch.matmul(eps_theta[:, half:], prototype) * sqrt_one_minus_alpha_bar_t_feature).to(device)

    # posterior mean
    y_t_m_1_hat = gamma_0 * y_0_pred + gamma_1 * y
    y_feature_t_m_1_hat = gamma_0 * y_feature_0_pred + gamma_1 * y_feature

    if variance_hat is None:
        variance_hat = _entropy_variance_hat(y_0_pred, relation_matrix, k)

    # posterior variance
    beta_t_hat = (sqrt_one_minus_alpha_bar_t_m_1.square()) / (sqrt_one_minus_alpha_bar_t.square()) * (1 - alpha_t) * variance_hat
    beta_t_hat_feature = (sqrt_one_minus_alpha_bar_t_m_1_feature.square()) / (sqrt_one_minus_alpha_bar_t_feature.square()) * (1 - alpha_t_feature)
    y_t_m_1 = y_t_m_1_hat.to(device) + beta_t_hat.sqrt().to(device) * z.to(device)
    y_t_m_1 = F.softmax(y_t_m_1, dim=1)
    # The same z drives both branches; in feature space it is scaled by
    # variance_hat and projected through the prototypes.
    y_feature_t_m_1 = y_feature_t_m_1_hat.to(device) + beta_t_hat_feature.sqrt().to(device) * torch.matmul(variance_hat * z, prototype).to(device)
    return y_t_m_1, y_feature_t_m_1

# Reverse function -- sample y_0 given y_1
def p_sample_t_1to0(model, y, y_feature, one_minus_alphas_bar_sqrt, prototype):
    """
    Final reverse step: map the step-1 state straight to the step-0 estimate.

    Schedule index 0 is used, which this file treats as diffusion timestep 1.
    The label estimate is softmax-normalised over dim 1 and then softmaxed a
    second time before being returned; the feature estimate is returned as is.

    Returns:
        (labels_0, features_0), both moved to the model's device.
    """
    device = next(model.parameters()).device
    step = torch.tensor([0]).to(device)  # index 0 <-> timestep t=1
    sigma_label = extract(one_minus_alphas_bar_sqrt, step, y)
    sigma_feat = extract(one_minus_alphas_bar_sqrt, step, y_feature)
    scale_label = (1 - sigma_label.square()).sqrt()
    scale_feat = (1 - sigma_feat.square()).sqrt()
    eps = model(y_feature, step)['noise'].to(device).detach()
    split = eps.shape[1] // 2

    # y_0 reparameterization
    labels_0 = 1 / scale_label * (y - eps[:, :split] * sigma_label).to(device)
    labels_0 = F.softmax(labels_0, dim=1)
    # y_feature_0 reparameterization
    features_0 = 1 / scale_feat * (
        y_feature - torch.matmul(eps[:, split:], prototype) * sigma_feat).to(device)

    # NOTE(review): labels_0 is already a softmax output; this second softmax
    # flattens the distribution (argmax unchanged). Confirm it is intended.
    labels_out = F.softmax(labels_0.to(device), dim=1)
    return labels_out, features_0.to(device)

def p_sample_loop(model, n_steps, alphas, one_minus_alphas_bar_sqrt, prototype, relation_matrix, k, variance_hat=None,
                  only_last_sample=False):
    """
    Run the full reverse diffusion chain and return the sampled labels/features.

    The chain starts from ``z ~ N(0, variance_hat)`` for labels and
    ``z @ prototype`` for features. It applies ``p_sample`` for
    t = n_steps-1, ..., 1 and finishes with ``p_sample_t_1to0``.

    Args:
        model: denoiser; its parameters determine the device.
        n_steps: number of diffusion steps (>= 1).
        alphas, one_minus_alphas_bar_sqrt: schedules forwarded to the samplers.
        prototype: matrix mapping label-space noise into feature space.
        relation_matrix, k: forwarded to ``p_sample``.
        variance_hat: std of the initial noise. It is also forwarded to every
            ``p_sample`` step. Required despite the ``None`` default.
        only_last_sample: if True, return only the final ``(y_0, y_feature_0)``.

    Returns:
        If ``only_last_sample``: ``(y_0, y_feature_0)``. Otherwise two lists of
        length ``n_steps + 1`` holding every intermediate label/feature state,
        ending with the step-0 estimate.

    Raises:
        ValueError: if ``variance_hat`` is None or ``n_steps < 1``.
    """
    if variance_hat is None:
        # torch.normal needs a concrete std tensor; fail with a clear message
        # rather than an opaque error from inside torch.normal.
        raise ValueError("p_sample_loop requires variance_hat (std of the initial noise)")
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    device = next(model.parameters()).device
    z = torch.normal(0., std=variance_hat).to(device)
    cur_y = z
    cur_y_feature = torch.matmul(z, prototype)

    y_p_seq = None if only_last_sample else [cur_y]
    y_feature_p_seq = None if only_last_sample else [cur_y_feature]
    for t in reversed(range(1, n_steps)):
        cur_y, cur_y_feature = p_sample(model, cur_y, cur_y_feature, t, alphas, one_minus_alphas_bar_sqrt,
                                        prototype, relation_matrix, k, variance_hat)  # y_{t-1}
        if not only_last_sample:
            y_p_seq.append(cur_y)
            y_feature_p_seq.append(cur_y_feature)

    y_0, y_feature_0 = p_sample_t_1to0(model, cur_y, cur_y_feature, one_minus_alphas_bar_sqrt, prototype)
    if only_last_sample:
        return y_0, y_feature_0
    y_p_seq.append(y_0)
    y_feature_p_seq.append(y_feature_0)
    return y_p_seq, y_feature_p_seq

if __name__ == "__main__":
    # Minimal smoke run: print a random 4x8 tensor.
    print(torch.rand(4, 8))